import os
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
import sys
import random
import numpy as np

import torch
import torch.optim as optim
import torch.nn.functional as F
from models.r2d2_2heads import R2D2Agent
from utils.memory import Memory, LocalBuffer
from tensorboardX import SummaryWriter

from models.r2d2_config import initial_exploration, batch_size, update_target, log_interval, eval_argmax, eval_interval, device, replay_memory_capacity, lr, sequence_length, local_mini_batch
from utils.pbmaze_config import iql_env_config as env_config
from phone_booth_collab_maze import PBCMaze
from collections import deque

# Directory (relative to the working directory) where the .npy result arrays are written.
RESULT_PATH = "results/"
# Directory where the trained agents are saved via save_model().
MODEL_PATH = "trained_models/"
# Number of independent training runs; run i is seeded with i in main().
NUM_RUNS = 5

def set_seed(seed):
    """Seed PyTorch's, Python's and NumPy's global RNGs with one value.

    Args:
        seed: Integer seed applied to all three generators, in that order.
    """
    for seeder in (torch.manual_seed, random.seed, np.random.seed):
        seeder(seed)

def evaluate(eval_env, a0_agent, a1_agent):
    """Play one greedy (epsilon = 0) evaluation episode with both agents.

    Each loop iteration is one round: agent 0 acts first (and its policy
    output is also passed to ``eval_env.step``), then agent 1 acts. Each
    agent carries its own recurrent (h, c) hidden state across the episode.
    The sequence of agent 1's actions is printed when the episode ends.

    Args:
        eval_env: Environment exposing ``reset()``, ``get_obs(agent_id)`` and
            ``step(agent_id, action[, policy])``.
        a0_agent: Agent 0, queried via ``get_action(obs, 0.0, hidden, mode="iql")``.
        a1_agent: Agent 1, queried the same way.

    Returns:
        Tuple ``(score, info)``: the sum of both agents' rewards over the
        episode, and the ``info`` returned by the final ``step`` call.
    """
    done = False
    score = 0
    eval_env.reset()

    # Zero-initialised (h, c) pairs; 16 must match the agents' recurrent
    # hidden size (presumably fixed inside R2D2Agent -- TODO confirm).
    a0_hidden = (torch.zeros(1, 1, 16, device=device), torch.zeros(1, 1, 16, device=device))
    a1_hidden = (torch.zeros(1, 1, 16, device=device), torch.zeros(1, 1, 16, device=device))
    a1_actions = []
    with torch.no_grad():
        while not done:
            # Agent 0's turn. The policy is moved to the CPU before .numpy():
            # .numpy() raises on CUDA tensors, and .cpu() is a no-op on CPU.
            a0_obs = torch.Tensor(eval_env.get_obs(0)).to(device)
            a0_policy, a0_action, a0_hidden = a0_agent.get_action(a0_obs, 0.0, a0_hidden, mode="iql")
            a0_reward, done, info = eval_env.step(0, a0_action, a0_policy.squeeze().detach().cpu().numpy())

            # Agent 1's turn. It always acts, and its `done` overwrites agent 0's,
            # so only agent 1's step can end the loop (same as in training).
            a1_obs = torch.Tensor(eval_env.get_obs(1)).to(device)
            _, a1_action, a1_hidden = a1_agent.get_action(a1_obs, 0.0, a1_hidden, mode="iql")
            a1_reward, done, info = eval_env.step(1, a1_action)

            score += a0_reward + a1_reward
            a1_actions.append(a1_action.item())
    print(a1_actions)
    return score, info

def _result_suffix():
    """Return the filename suffix encoding the reward-shaping and eval modes.

    ``_mi_log2`` when ``env_config['use_mi_shaping']`` is set, otherwise
    ``_ir`` when ``env_config['use_intermediate_reward']`` is set, otherwise
    nothing; followed by ``_argmax`` when ``eval_argmax`` is enabled.
    """
    if env_config['use_mi_shaping']:
        shaping = "_mi_log2"
    elif env_config['use_intermediate_reward']:
        shaping = "_ir"
    else:
        shaping = ""
    return shaping + ("_argmax" if eval_argmax else "")


def _format_progress(run_idx, episode, running, score, sender_to_pb, receiver_to_pb, epsilon):
    """Build the one-line progress report printed during training."""
    return 'Run {} | {} episode | score: {:.2f} | reward sum: {:.2f} | SenderToPB: {:.2f} | ReceiverToPB: {:.2f} | epsilon: {:.2f}'.format(
        run_idx + 1, episode, running, score, sender_to_pb, receiver_to_pb, epsilon)


def main():
    """Train two independent R2D2 learners (IQL) on PBCMaze over NUM_RUNS seeds.

    For each run: seed the RNGs with the run index, build a training env and
    an evaluation env that receives the training env's config (with MI
    shaping and intermediate rewards switched off), create one R2D2Agent per
    player, and train for ``num_episodes`` episodes in which agent 0 and
    agent 1 alternate turns. Metrics are collected per episode, or per
    evaluation when ``eval_argmax`` is set. After all runs they are saved as
    ``.npy`` arrays under RESULT_PATH, and the last run's agents are saved
    under MODEL_PATH.
    """
    sender_time_to_booth_result = []
    receiver_time_to_booth_result = []
    reward_result = []
    eval_reward_result = []
    running_reward_result = []
    running_eval_reward_result = []
    for run_idx in range(NUM_RUNS):
        print("Run: " + str(run_idx + 1))
        set_seed(run_idx)

        # Env
        num_episodes = 40000
        env = PBCMaze(env_args=env_config)
        env.reset()
        eval_env = PBCMaze(env_args=env_config)
        eval_env.reset()
        # Copy the training env's config into the eval env, then score
        # evaluation without MI shaping or intermediate rewards.
        eval_env.load_env_config(env.save_env_config())
        eval_env.use_mi_shaping = False
        eval_env.use_intermediate_reward = False

        # Agent 0 obs: ((channel, width, height), goal feature)
        # Agent 1 obs: ((channel, width, height), communication token)
        a0_input_shape = env.get_obs_size(0)
        a1_input_shape = env.get_obs_size(1)
        a0_num_actions = 7
        a1_num_actions = 5

        a0_agent = R2D2Agent(a0_input_shape, a0_num_actions, Memory(replay_memory_capacity), LocalBuffer(), lr, batch_size, device)
        a1_agent = R2D2Agent(a1_input_shape, a1_num_actions, Memory(replay_memory_capacity), LocalBuffer(), lr, batch_size, device)

        writer = SummaryWriter('logs')

        running_score = 0
        running_eval_score = 0
        epsilon = 1.0
        steps = 0
        loss = 0
        per_run_sender_time_to_booth_list = []
        per_run_receiver_time_to_booth_list = []
        per_run_reward = []
        per_run_eval_reward = []
        per_run_running_reward = []
        per_run_running_eval_reward = []
        for e in range(num_episodes):
            done = False
            score = 0
            # None until agent 1 has acted once this episode; gates the
            # "close agent 1's pending transition" step below.
            a1_reward = None
            env.reset()

            a0_hidden = (torch.zeros(1, 1, 16, device=device), torch.zeros(1, 1, 16, device=device))
            a1_hidden = (torch.zeros(1, 1, 16, device=device), torch.zeros(1, 1, 16, device=device))

            while not done:
                steps += 1

                # --- Agent 0's turn ---
                a0_obs = torch.Tensor(env.get_obs(0)).to(device)
                _, a0_action, a0_next_hidden = a0_agent.get_action(a0_obs, epsilon, a0_hidden, mode="iql")
                a0_reward, done, info = env.step(0, a0_action)

                # Agent 1's previous transition can only be closed now: its
                # reward is shared (a1 + a0) and its next observation is the
                # one seen after agent 0's move.
                if a1_reward is not None:
                    mask = 0 if done else 1
                    next_a1_obs = torch.Tensor(env.get_obs(1)).to(device)
                    a1_agent.local_buffer.push(a1_obs, next_a1_obs, a1_action, a1_reward + a0_reward, mask, a1_hidden)
                    a1_hidden = a1_next_hidden
                    if len(a1_agent.local_buffer.memory) == local_mini_batch:
                        a1_agent.push_to_memory()

                if steps > initial_exploration and len(a1_agent.memory) > batch_size:
                    loss, _ = a1_agent.train_model()
                    if steps % update_target == 0:
                        a1_agent.update_target_model()

                # --- Agent 1's turn ---
                # Agent 1 acts even if agent 0's step reported done; this
                # rebinds `done`, so only agent 1's step ends the episode.
                a1_obs = torch.Tensor(env.get_obs(1)).to(device)
                _, a1_action, a1_next_hidden = a1_agent.get_action(a1_obs, epsilon, a1_hidden, mode="iql")
                a1_reward, done, info = env.step(1, a1_action)

                # Close agent 0's transition with the shared reward (a0 + a1).
                mask = 0 if done else 1
                next_a0_obs = torch.Tensor(env.get_obs(0)).to(device)
                a0_agent.local_buffer.push(a0_obs, next_a0_obs, a0_action, a0_reward + a1_reward, mask, a0_hidden)
                a0_hidden = a0_next_hidden
                if len(a0_agent.local_buffer.memory) == local_mini_batch:
                    a0_agent.push_to_memory()

                if done:
                    # No later agent-0 turn will close agent 1's final
                    # transition, so push it now with its own reward (mask is 0).
                    next_a1_obs = torch.Tensor(env.get_obs(1)).to(device)
                    a1_agent.local_buffer.push(a1_obs, next_a1_obs, a1_action, a1_reward, mask, a1_hidden)
                    a1_hidden = a1_next_hidden
                    if len(a1_agent.local_buffer.memory) == local_mini_batch:
                        a1_agent.push_to_memory()

                score += a0_reward + a1_reward

                if steps > initial_exploration and len(a0_agent.memory) > batch_size:
                    # The shared epsilon decays linearly (floor 0.1), once per
                    # agent-0 training step.
                    epsilon -= 0.00001
                    epsilon = max(epsilon, 0.1)
                    loss, _ = a0_agent.train_model()
                    if steps % update_target == 0:
                        a0_agent.update_target_model()

            running_score = 0.99 * running_score + 0.01 * score
            # Steps to phone booth
            if eval_argmax:
                if e % eval_interval == 0:
                    eval_score, info = evaluate(eval_env, a0_agent, a1_agent)
                    running_eval_score = 0.99 * running_eval_score + 0.01 * eval_score
                    sender_time_to_pb = info["sender_time_to_booth"]
                    receiver_time_to_pb = info["receiver_time_to_booth"]
                    per_run_sender_time_to_booth_list.append(sender_time_to_pb)
                    per_run_receiver_time_to_booth_list.append(receiver_time_to_pb)
                    per_run_eval_reward.append(eval_score)
                    per_run_running_eval_reward.append(running_eval_score)
                    print(_format_progress(run_idx, e, running_eval_score, eval_score,
                                           sender_time_to_pb, receiver_time_to_pb, epsilon))
                    sys.stdout.flush()
            else:
                # Metrics come from the `info` of the last training step.
                sender_time_to_pb = info["sender_time_to_booth"]
                receiver_time_to_pb = info["receiver_time_to_booth"]
                per_run_sender_time_to_booth_list.append(sender_time_to_pb)
                per_run_receiver_time_to_booth_list.append(receiver_time_to_pb)
                per_run_reward.append(score)
                per_run_running_reward.append(running_score)
                if e % log_interval == 0:
                    print(_format_progress(run_idx, e, running_score, score,
                                           sender_time_to_pb, receiver_time_to_pb, epsilon))
                    writer.add_scalar('log/score', float(running_score), e)
                    writer.add_scalar('log/loss', float(loss), e)
                    sys.stdout.flush()

        # Flush and release the event file; a new writer is opened every run.
        writer.close()

        sender_time_to_booth_result.append(per_run_sender_time_to_booth_list)
        receiver_time_to_booth_result.append(per_run_receiver_time_to_booth_list)
        if eval_argmax:
            eval_reward_result.append(per_run_eval_reward)
            running_eval_reward_result.append(per_run_running_eval_reward)
        else:
            reward_result.append(per_run_reward)
            running_reward_result.append(per_run_running_reward)

    # Save results
    os.makedirs(RESULT_PATH, exist_ok=True)
    os.makedirs(MODEL_PATH, exist_ok=True)

    suffix = _result_suffix()
    sender_result_filename = "iql_sender_time_to_pb" + suffix + ".npy"
    receiver_result_filename = "iql_receiver_time_to_pb" + suffix + ".npy"
    reward_result_filename = "iql_reward" + suffix + ".npy"
    running_reward_result_filename = "iql_running_reward" + suffix + ".npy"
    # NOTE(review): the space after "model" ends up in the saved filename.
    # Kept so existing checkpoints/loaders still match -- TODO confirm and drop.
    sender_model_path = MODEL_PATH + "iql_sender_model " + suffix
    receiver_model_path = MODEL_PATH + "iql_receiver_model " + suffix

    if eval_argmax:
        np.save(RESULT_PATH + reward_result_filename, np.array(eval_reward_result))
        np.save(RESULT_PATH + running_reward_result_filename, np.array(running_eval_reward_result))
    else:
        np.save(RESULT_PATH + reward_result_filename, np.array(reward_result))
        np.save(RESULT_PATH + running_reward_result_filename, np.array(running_reward_result))

    np.save(RESULT_PATH + sender_result_filename, np.array(sender_time_to_booth_result))
    np.save(RESULT_PATH + receiver_result_filename, np.array(receiver_time_to_booth_result))

    # Save model (agents from the last run only).
    a0_agent.save_model(sender_model_path)
    a1_agent.save_model(receiver_model_path)


# Script entry point: run every training run, then save results and models.
if __name__=="__main__":
    main()
